
# Correct P val in Delta table --------------------------------------------
# Adjust the one-sample t-test p-values for multiple comparisons and flag
# which rows are significant.
#
# Delta.OneSample:            table with an `OneSamplTtest.pval` column. It is
#                             ungrouped first, so the correction runs over all
#                             rows together.
# make_paper_p.adjust.method: `method` argument for stats::p.adjust()
#                             (one of p.adjust.methods, e.g. "BH").
# alpha:                      threshold applied to the corrected p-values
#                             (defaults to 0.05, the previously hard-coded value).
#
# Returns Delta.OneSample, ungrouped, with two added columns:
#   Corrected.OneSamplTtest.pval  adjusted p-value
#   Significant                   "Significant" / "Not_significant"
#                                 (NA where the corrected p-value is NA)
correct_pval_delta <- function(Delta.OneSample, make_paper_p.adjust.method,
                               alpha = 0.05) {
  Delta.OneSample %>%
    ungroup() %>%
    mutate(
      Corrected.OneSamplTtest.pval = p.adjust(OneSamplTtest.pval,
                                              method = make_paper_p.adjust.method),
      Significant = if_else(Corrected.OneSamplTtest.pval < alpha,
                            "Significant",
                            "Not_significant")
    )
}

# Adjust interaction-model p-values for multiple comparisons, separately
# within each level of `Stat`, and flag which rows are significant.
#
# Interaction.node:           table with `Stat`, `p.value` and `From_both`
#                             columns. Any existing grouping is replaced by
#                             `Stat` for the correction.
# make_paper_p.adjust.method: `method` argument for stats::p.adjust().
# alpha:                      threshold applied to the corrected p-values
#                             (defaults to 0.05, the previously hard-coded value).
#
# Returns the ungrouped table with `corrected.p.value` and `Significant`
# ("Significant" / "Not_significant") added, and `From` overwritten with the
# values of `From_both`.
correct_pval_interaction <- function(Interaction.node, make_paper_p.adjust.method,
                                     alpha = 0.05) {
  Interaction.node %>%
    group_by(Stat) %>%
    # One correction family per Stat level.
    mutate(corrected.p.value = p.adjust(p.value, method = make_paper_p.adjust.method)) %>%
    ungroup() %>%
    mutate(Significant = if_else(corrected.p.value < alpha,
                                 "Significant",
                                 "Not_significant"),
           From = From_both)
}

# Make full table ---------------------------------------------------------

make_full_table <- function(Data.grouped, Delta.OneSample) {

  # Put the impulsivity groups side by side for every edge/treatment/measure.
  by.impulsivity <- Data.grouped %>%
    gather(key = "measure", value = "value", -From, -To, -Impulsivity, -Treatment) %>%
    spread(key = Impulsivity, value = value)

  # high - low per measure, with each measure spread back into its own column.
  edge.delta <- by.impulsivity %>%
    mutate(Delta = high - low) %>%
    select(-high, -low) %>%
    spread(key = measure, value = Delta)

  # Copy of every edge with From and To swapped, so that each node appears in
  # `From` for all of its edges once both directions are stacked.
  swapped <- edge.delta %>%
    rename(swap.to = From, swap.from = To) %>%
    rename(From = swap.from, To = swap.to)

  # NOTE(review): " L| R" is unanchored, so it also removes " L"/" R" in the
  # middle of a name (before any word starting with L or R), not only a
  # trailing hemisphere tag. It is left unchanged because Delta.OneSample's
  # `From` keys are presumably built the same way upstream; confirm before
  # anchoring it (e.g. " [LR]$").
  edges.both.ways <- bind_rows(edge.delta, swapped) %>%
    mutate(From = gsub(pattern = " L| R", replacement = "", From)) %>%
    select(From, Treatment, To, correlation.rho, FisherZ)

  # Attach the per-node one-sample statistics to every edge row.
  left_join(edges.both.ways, Delta.OneSample, by = c("From", "Treatment"))
}

# Make suppl figure 1 -----------------------------------------------------

# Supplementary figure 1: per-node distribution of the high - low Fisher Z
# difference for one treatment, with nodes ordered by their one-sample t value.
#
# Delta.OneSample:    one row per node (`From`) and `Treatment`, with
#                     `OneSamplTtest.tval`, `Corrected.OneSamplTtest.pval` and
#                     `Significant` columns.
# Delta.grouped.full: edge-level table with `From`, `Treatment`,
#                     `correlation.rho` and `FisherZ` (presumably the output
#                     of make_full_table()).
# Treat:              the treatment to plot.
# theme_fig1:         ggplot theme added to the plot.
#
# Relies on the globals `colorz`, which supplies the fill colours (assigned in
# reverse plot order; assumes one colour per plotted node, TODO confirm), and
# `grey`, the fill for non-significant nodes.
#
# The boxes are drawn from precomputed quantiles. The whiskers run from the
# 10th to the 90th percentile (columns y0f/y100f, despite their names), the box
# covers 25/50/75, and the black dot is the mean. Stars encode the corrected
# p-value: * < 0.05, ** < 0.01, *** < 0.001, **** < 0.0001.
#
# Returns the ggplot object.
make_suppl_figure1 <- function(Delta.OneSample, Delta.grouped.full, Treat, theme_fig1) {

  Temp <- Delta.OneSample %>%
    filter(Treatment == Treat)

  # Node order used for the axis: largest t value first.
  node.levels <- Temp$From[order(-Temp$OneSamplTtest.tval)]

  Temp.grouped.mean <- Delta.grouped.full %>%
    filter(Treatment == Treat) %>%
    group_by(From, Treatment) %>%
    summarise(MeanCor = mean(correlation.rho), SDCor = sd(correlation.rho), NCor = n(),
              # Standard error of the mean is SD / sqrt(N), not SD / N.
              SEMCor = SDCor / sqrt(NCor),
              y0 = quantile(correlation.rho, 0.10), y25 = quantile(correlation.rho, 0.25),
              y50 = median(correlation.rho),
              y75 = quantile(correlation.rho, 0.75), y100 = quantile(correlation.rho, 0.90),
              MeanFish = mean(FisherZ), SDFish = sd(FisherZ), NFish = n(),
              SEMFish = SDFish / sqrt(NFish),
              y0f = quantile(FisherZ, 0.10), y25f = quantile(FisherZ, 0.25),
              y50f = median(FisherZ),
              y75f = quantile(FisherZ, 0.75), y100f = quantile(FisherZ, 0.90)) %>%
    ungroup()

  Temp.grouped.mean <- left_join(Temp.grouped.mean, Temp, by = c("From", "Treatment")) %>%
    mutate(From = factor(From, levels = node.levels)) %>%
    arrange(From) %>%
    mutate(N = n(), Color = rev(colorz),
           Color = if_else(Significant == "Not_significant", grey, Color)) %>%
    # case_when() stops at the first matching branch, so each bound only needs
    # one side. A missing p-value matches no branch and yields NA.
    mutate(p.text = case_when(
      Corrected.OneSamplTtest.pval >= 0.05   ~ "",
      Corrected.OneSamplTtest.pval >= 0.01   ~ "*",
      Corrected.OneSamplTtest.pval >= 0.001  ~ "**",
      Corrected.OneSamplTtest.pval >= 0.0001 ~ "***",
      Corrected.OneSamplTtest.pval <  0.0001 ~ "****"
    ))

  plot <- ggplot(data = Temp.grouped.mean,
                 aes(x = From, fill = Color)) +
    geom_boxplot(aes(ymin = y0f, lower = y25f, middle = y50f,
                     upper = y75f, ymax = y100f), stat = "identity") +
    geom_point(aes(y = MeanFish), color = "black", size = 3) +
    # NOTE(review): when y100f is negative, y100f * 1.1 lies below the upper
    # whisker end, so the stars overlap the whisker; confirm this is acceptable.
    geom_text(aes(y = y100f * 1.1, label = p.text), fontface = "bold", size = 5) +
    geom_hline(yintercept = 0, linetype = "solid", color = "black") +
    scale_fill_identity() +
    xlab("") +
    ylab("Fisher difference: high - low") +
    theme_fig1 +
    theme(axis.text.x  = element_text(size = 22, angle = 0, hjust = 0.6,
                                      face = "bold", vjust = 0.4),
          axis.title.y = element_text(face = "bold", colour = "black",
                                      size = 28, vjust = 0.4),
          axis.title.x = element_text(face = "bold", colour = "black",
                                      size = 28, vjust = 0.4)) +
    coord_flip()
    #coord_fixed(ratio = 100/(1.35))

  plot
}



# Suppl fig 2 -------------------------------------------------------------

# Supplementary figure 2: per-node distribution of Fisher Z values, filled by
# treatment and faceted (columns) by impulsivity group. Nodes are ordered by
# the F value of the "Impulsivity_Treatment" interaction term, largest first.
#
# Data.grouped.impulse:      edge-level table with `From`, `To`, `Treatment`,
#                            `Impulsivity`, `correlation.rho` and `FisherZ`.
# Data.impulse.interaction:  model results with `Stat`, `F value` and `From`
#                            (presumably one row per node and Stat; TODO confirm).
# vehicle.grey, cmp11.green: fill colours, applied in order to the Treatment
#                            levels, i.e. rev(Treatments.groups).
#
# Relies on the globals `Treatments.groups`, `Impulsivity.groups` and
# `theme_fig1`.
#
# Returns the ggplot object.
make_suppl_figure2 <- function(Data.grouped.impulse, Data.impulse.interaction,
                               vehicle.grey, cmp11.green) {

  # Stack each edge with a From/To-swapped copy so every node appears in
  # `From` for all of its edges, then strip hemisphere tags from `From`.
  # NOTE(review): " L| R" is unanchored and also removes " L"/" R" inside
  # names; it is kept as-is so `From` stays consistent with Node.order's source.
  Node.full <- Data.grouped.impulse %>%
    rename(From2 = From, To2 = To) %>%
    rename(From = To2, To = From2) %>%
    bind_rows(Data.grouped.impulse, .) %>%
    mutate(From = gsub(pattern = " L| R", replacement = "", From)) %>%
    select(From, To, Treatment, Impulsivity, correlation.rho, FisherZ)

  # Largest interaction F value first. unique() keeps factor() from failing
  # on duplicated levels if a node appears more than once.
  Node.order <- Data.impulse.interaction %>%
    filter(Stat == "Impulsivity_Treatment") %>%
    arrange(desc(`F value`)) %>%
    pull(From) %>%
    unique()

  Node.mean <- Node.full %>%
    group_by(From, Treatment, Impulsivity) %>%
    summarise(MeanCor = mean(correlation.rho), SDCor = sd(correlation.rho), NCor = n(),
              # Standard error of the mean is SD / sqrt(N), not SD / N.
              SEMCor = SDCor / sqrt(NCor),
              y0 = quantile(correlation.rho, 0.10), y25 = quantile(correlation.rho, 0.25),
              y50 = median(correlation.rho),
              y75 = quantile(correlation.rho, 0.75), y100 = quantile(correlation.rho, 0.90),
              MeanFish = mean(FisherZ), SDFish = sd(FisherZ), NFish = n(),
              SEMFish = SDFish / sqrt(NFish),
              y0f = quantile(FisherZ, 0.10), y25f = quantile(FisherZ, 0.25),
              y50f = median(FisherZ),
              y75f = quantile(FisherZ, 0.75), y100f = quantile(FisherZ, 0.90)) %>%
    ungroup() %>%
    mutate(Treatment = factor(Treatment, levels = rev(Treatments.groups)),
           Impulsivity = factor(Impulsivity, levels = rev(Impulsivity.groups)),
           From = factor(From, levels = Node.order))

  # Boxes are drawn from the precomputed Fisher Z quantiles: whiskers at the
  # 10th/90th percentile, box at 25/50/75. The dot marks the mean.
  plotCor <- ggplot(data = Node.mean, aes(x = From,
                                          fill = Treatment)) +
    geom_boxplot(aes(ymin = y0f, lower = y25f, middle = y50f, upper = y75f, ymax = y100f),
                 stat = "identity") +
    geom_point(aes(y = MeanFish), color = "black", position = position_dodge(width = 0.9)) +
    scale_fill_manual(values = c(vehicle.grey, cmp11.green)) +
    theme_bw() +
    # NOTE(review): the plotted values are Fisher Z, not raw correlations;
    # confirm this axis title is intended.
    ylab("Correlation") +
    xlab("") +
    theme_fig1 +
    theme(axis.text.x  = element_text(size = 22, angle = 0, hjust = 0.6,
                                      face = "bold", vjust = 0.4),
          axis.title.y = element_text(face = "bold", colour = "black",
                                      size = 28, vjust = 0.4),
          axis.title.x = element_text(face = "bold", colour = "black",
                                      size = 28, vjust = 0.4)) +
    # facet_grid(~Impulsivity) draws its strips along the top, and those are
    # styled by strip.text.x; strip.text.y would only affect side strips.
    theme(strip.background = element_blank(),
          strip.text.x = element_text(size = 28, colour = "black", face = "bold")) +
    coord_flip() +
    facet_grid(~Impulsivity)

  plotCor
}

# Process heatmap colors --------------------------------------------------

# Read the brain-region label table and drop the areas listed in remove.areas.
#
# Labels.data.file: path to a ";"-delimited file with an `fMRIName` column.
# remove.areas:     area names to drop. Underscores are converted to spaces
#                   before matching against `fMRIName`.
#
# Returns the filtered label table. On failure it prints "ERROR!" and the
# condition, then re-signals the error. Callers therefore never receive a
# condition object in place of a table.
filter_heatmap_colors <- function(Labels.data.file, remove.areas) {
  tryCatch({
    remove.areas <- str_replace_all(remove.areas, "_", " ")

    Labels.data <- read_delim(file = Labels.data.file, delim = ";")

    Labels.data %>%
      filter(!(fMRIName %in% remove.areas))

  }, error = function(e) {
    print("ERROR!")
    print(e)
    # Previously the value of print(e) (the condition itself) was returned.
    stop(e)
  })
}

# Attach the level colours to the label table and sort it for the heatmap.
#
# Labels.data:         label table with `Parental_brain_region` and
#                      `Hemisphere` columns.
# Heatmap.Level.color: colour lookup keyed by `Parental_brain_region` (taken
#                      from the config file).
#
# Relies on the global `Heatmap.brain.levels` for the region order. The joined
# table must contain a `Color` column, which is used as the last sort key.
# Rows are sorted by Hemisphere (descending), then by region in
# Heatmap.brain.levels order, then by Color.
#
# Returns the sorted table. On failure it prints "ERROR!" and the condition,
# then re-signals the error instead of returning the condition object.
process_heatmap_colors <- function(Labels.data, Heatmap.Level.color) {
  tryCatch({
    left_join(Labels.data,
              Heatmap.Level.color,
              by = "Parental_brain_region") %>%
      mutate(Parental_brain_region = factor(Parental_brain_region,
                                            levels = Heatmap.brain.levels)) %>%
      arrange(desc(Hemisphere),
              Parental_brain_region, Color)

  }, error = function(e) {
    print("ERROR!")
    print(e)
    stop(e)
  })
}

# Heatmap plot ----------------------------------------------------------------------------------------------------
# Data must be a full table. Pre-filter Heatmap.brains so it contains only the needed areas: every From/To value in Data must appear in its Atlas column.

# Draw a square From x To heatmap over the atlas areas, annotated by parent
# region (rows) and hemisphere (columns).
#
# Data:             table with `From`, `To` and the `measure` column. Every
#                   From/To value must appear in Annotations.data$Atlas;
#                   otherwise the matrix indexing fails (subscript out of bounds).
# measure:          name of the Data column used to fill the cells.
# breaks, color:    passed straight to pheatmap().
# plot.name:        not used in the body.
# Annotations.data: one row per atlas area, with unique `Atlas` names and
#                   `Parental_brain_region`, `Hemisphere` and
#                   `Color.simplyfied` columns. The row order sets the matrix
#                   order. Defaults to the global `Heatmap.brains`.
#
# Only the From -> To cells present in Data are filled (nothing is mirrored);
# every other cell is NA. The heatmap is drawn immediately (silent = FALSE),
# and the pheatmap gtable is returned with its 4th grob rotated.
plot_heatmap <- function(Data, measure = "adj.p.value", breaks = NA,
                         color = colorRampPalette(rev(brewer.pal(n = 7, name =
                                                                   "RdYlBu")))(100),
                         plot.name = "xxx",
                         Annotations.data = Heatmap.brains){

  # Region -> colour lookup built from unique (region, colour) PAIRS. Taking
  # unique() of each column separately misaligns names and colours, or makes
  # names<- fail, as soon as two regions share a colour.
  region.colors <- unique(Annotations.data[c("Parental_brain_region", "Color.simplyfied")])
  Var1 <- setNames(region.colors$Color.simplyfied,
                   region.colors$Parental_brain_region)
  # The first hemisphere seen gets navy, the second darkgreen.
  Var2        <- c("navy", "darkgreen")
  names(Var2) <- unique(Annotations.data$Hemisphere)
  anno_colors <- list(Parental_brain_region = Var1, Hemisphere = Var2)

  annotation_rows <- Annotations.data %>% select(Parental_brain_region) %>%
    data.frame(.)
  rownames(annotation_rows) <- Annotations.data$Atlas

  annotation_columns <- Annotations.data %>% select(Hemisphere) %>%
    data.frame(.)
  rownames(annotation_columns) <- Annotations.data$Atlas

  # Empty square matrix named by atlas area. Atlas names are unique here
  # (the rownames<- calls above reject duplicates).
  atlas <- Annotations.data$Atlas
  mat <- matrix(NA, nrow = length(atlas), ncol = length(atlas),
                dimnames = list(atlas, atlas))

  # Fill the cells addressed by name via a two-column (From, To) index matrix.
  mat[as.matrix(Data[c("From", "To")])] <- Data[[measure]]

  plot <- pheatmap(mat,
                   breaks = breaks,
                   color = color,
                   border_color = "grey60",
                   cluster_rows = FALSE,
                   cluster_cols = FALSE,
                   annotation_col = annotation_columns,
                   annotation_row = annotation_rows,
                   annotation_colors = anno_colors,
                   annotation_names_row = FALSE,
                   annotation_names_col = FALSE,
                   show_rownames = FALSE,
                   show_colnames = FALSE,
                   silent = FALSE)$gtable

  # NOTE(review): grob 4 depends on pheatmap's layout for these settings;
  # presumably a legend or annotation element. Confirm if the options change.
  plot$grobs[[4]]$rot <- -270
  plot$grobs[[4]]$just <- "bottom"
  plot$grobs[[4]]$hjust <- 1

  plot
}
